# Rdkit import should be first, do not move it
try:
    from rdkit import Chem
except ModuleNotFoundError:
    pass

from utils import utils
from configs.parse_args import eval_sample_parse_args
from qm9 import dataset as qmds
from dadm.get_models import get_da_diffusion

from utils.utils import assert_correctly_masked
import torch
import pickle
import qm9.visualizer as vis
from qm9.analyze import check_stability
from os.path import join
from traintest.sampling import sample_chain, sample
from configs.datasets_config import get_dataset_info


def check_mask_correct(variables, node_mask):
    """Run ``assert_correctly_masked`` on every entry of ``variables``.

    Args:
        variables: Iterable of objects to validate; each one is passed,
            in order, to ``assert_correctly_masked``.
        node_mask: Mask forwarded unchanged as the second argument of
            every ``assert_correctly_masked`` call.
    """
    for tensor in variables:
        assert_correctly_masked(tensor, node_mask)


def save_and_sample_chain(args, eval_args, device, flow, target_data_loader,
                          n_tries, n_nodes, dataset_info, id_from=0,
                          num_chains=100):
    """Sample ``num_chains`` chains and write each one to its own folder.

    For chain ``i`` a batch is taken from ``target_data_loader``, handed to
    ``sample_chain``, and the results are written below
    ``<eval_args.model_path>/sample_<eval_args.target_domain>/chain_<i>/``:

    * the sampled chain via ``vis.save_xyz_file`` (``name='chain'``),
    * a rendering via ``vis.visualize_chain_uncertainty``,
    * the target batch via ``vis.save_xyz_file`` (``name='target_molecule'``),
    * ``target_molecule.pickle`` (the raw target batch) and
      ``generated_molecule.pickle`` (``one_hot``, ``x``, ``charges``).

    Args:
        args: Training-time arguments, forwarded to ``sample_chain``.
        eval_args: Evaluation arguments; ``model_path`` and
            ``target_domain`` are read to build the output directory.
        device: Device forwarded to ``sample_chain``.
        flow: Generative model forwarded to ``sample_chain``.
        target_data_loader: Iterable of batches; each batch is indexed with
            the keys ``'positions'``, ``'atom_mask'``, ``'one_hot'`` and
            ``'charges'``.
        n_tries: Forwarded to ``sample_chain``.
        n_nodes: Accepted for interface compatibility; not used here.
        dataset_info: Forwarded to ``sample_chain`` and the ``vis`` helpers.
        id_from: Start index forwarded to ``vis.save_xyz_file``.
        num_chains: Number of chains to sample.

    Returns:
        Tuple ``(one_hot, charges, x)`` of the last chain sampled, or
        ``(None, None, None)`` when ``num_chains <= 0``.
    """
    # Bind the return values up front so that an empty loop
    # (num_chains <= 0) returns cleanly instead of raising
    # UnboundLocalError.
    one_hot = charges = x = None

    for i in range(num_chains):
        # A new iterator is created on every pass, so this always takes the
        # first batch that freshly created iterator yields.
        target_data = next(iter(target_data_loader))

        # Keep the trailing slash: the directory string is passed verbatim
        # to the vis helpers.
        chain_dir = join(eval_args.model_path,
                         f'sample_{eval_args.target_domain}/chain_{i}/')

        one_hot, charges, x = sample_chain(
            args, device, flow, target_data, n_tries, dataset_info)

        vis.save_xyz_file(
            chain_dir, one_hot, charges, x,
            dataset_info, id_from, name='chain')

        vis.visualize_chain_uncertainty(
            chain_dir, dataset_info, spheres_3d=True)

        # Save target molecules
        vis.save_xyz_file(
            chain_dir,
            target_data['one_hot'].int(), target_data['charges'],
            target_data['positions'], dataset_info, id_from,
            name='target_molecule', node_mask=target_data['atom_mask'])

        # NOTE(review): the pickles are written after the xyz files, so the
        # directory is presumably created by vis.save_xyz_file -- confirm.
        with open(join(chain_dir, 'target_molecule.pickle'), "wb") as file:
            pickle.dump(target_data, file)

        generated_molecule = {'one_hot': one_hot, 'x': x, 'charges': charges}
        with open(join(chain_dir, 'generated_molecule.pickle'), "wb") as file:
            pickle.dump(generated_molecule, file)

    return one_hot, charges, x


def sample_different_sizes_and_save(args, eval_args, device, generative_model, target_data,
                                    nodes_dist, dataset_info, n_samples=10, id_from=0):
    """Sample ``n_samples`` molecules of drawn sizes and save them as xyz.

    The sizes come from ``nodes_dist.sample(n_samples)`` and are passed to
    ``sample`` as ``nodesxsample``. The result is written with
    ``vis.save_xyz_file`` into
    ``<eval_args.model_path>/sample_<eval_args.target_domain>/molecules/``
    under the name ``'molecule'``.

    Args:
        args: Training-time arguments, forwarded to ``sample``.
        eval_args: Evaluation arguments; ``model_path`` and
            ``target_domain`` are read to build the output directory.
        device: Device forwarded to ``sample``.
        generative_model: Model forwarded to ``sample``.
        target_data: Target batch forwarded to ``sample``.
        nodes_dist: Object whose ``sample(n)`` method yields the sizes.
        dataset_info: Forwarded to ``sample`` and ``vis.save_xyz_file``.
        n_samples: Number of sizes to draw.
        id_from: Start index forwarded to ``vis.save_xyz_file``.
    """
    sizes = nodes_dist.sample(n_samples)
    sampled = sample(
        args, device, generative_model, target_data, dataset_info,
        nodesxsample=sizes)
    one_hot, charges, x, node_mask = sampled

    out_dir = join(eval_args.model_path,
                   f'sample_{eval_args.target_domain}/molecules/')
    vis.save_xyz_file(
        out_dir, one_hot, charges, x,
        id_from=id_from,
        name='molecule',
        dataset_info=dataset_info,
        node_mask=node_mask)


def sample_only_stable_different_sizes_and_save(
        args, eval_args, device, generative_dm, target_data, nodes_dist,
        dataset_info, n_samples=10, n_tries=50):
    """Sample ``n_tries`` molecules and save up to ``n_samples``, stable first.

    All ``n_tries`` candidates are generated in one ``sample`` call. They
    are then visited in order; a candidate is written with
    ``vis.save_xyz_file`` (``name='molecule_stable'``) when
    ``check_stability`` reports it as stable, or when the number of
    candidates left after it is no larger than the number of molecules
    still required. Iteration stops once ``n_samples`` have been saved.

    Args:
        args: Training-time arguments, forwarded to ``sample``.
        eval_args: Evaluation arguments; ``model_path`` and
            ``target_domain`` are read to build the output directory.
        device: Device forwarded to ``sample``.
        generative_dm: Model forwarded to ``sample``.
        target_data: Target batch forwarded to ``sample``.
        nodes_dist: Object whose ``sample(n)`` method yields the sizes.
        dataset_info: Forwarded to ``sample``, ``check_stability`` and
            ``vis.save_xyz_file``.
        n_samples: Number of molecules to save.
        n_tries: Number of candidates to generate; must exceed
            ``n_samples``.

    Raises:
        AssertionError: If ``n_tries`` is not greater than ``n_samples``.
    """
    assert n_tries > n_samples

    sizes = nodes_dist.sample(n_tries)
    one_hot, charges, x, node_mask = sample(
        args, device, generative_dm, target_data, dataset_info,
        nodesxsample=sizes)

    saved = 0
    for idx in range(n_tries):
        # Length-1 slice so the leading dimension is kept.
        row = slice(idx, idx + 1)

        # The mask sum gives the atom count used to trim the candidate.
        num_atoms = int(node_mask[row].sum().item())
        atom_type = one_hot[row, :num_atoms].argmax(2).squeeze(0).cpu().detach().numpy()
        coords = x[row, :num_atoms].squeeze(0).cpu().detach().numpy()
        is_stable = check_stability(coords, atom_type, dataset_info)[0]

        attempts_left = n_tries - idx - 1
        still_needed = n_samples - saved

        # Skip unstable candidates while enough attempts remain.
        if not is_stable and attempts_left > still_needed:
            continue

        if is_stable:
            print('Found stable mol.')

        out_dir = join(eval_args.model_path,
                       f'sample_{eval_args.target_domain}/molecules/')
        vis.save_xyz_file(
            out_dir,
            one_hot[row], charges[row], x[row],
            id_from=saved, name='molecule_stable',
            dataset_info=dataset_info,
            node_mask=node_mask[row])
        saved += 1

        if saved >= n_samples:
            break


def main():
    """Load a trained model from ``eval_args.model_path`` and sample from it.

    Steps, in order:

    1. Parse the evaluation arguments and load the pickled training
       arguments from ``<model_path>/args.pickle``.
    2. Build the dataloaders and the generative model, then load the
       ``generative_model_ema.npy`` checkpoint when ``args.ema_decay > 0``
       and ``generative_model.npy`` otherwise.
    3. Sample and save a batch of molecules, then a set of preferably
       stable molecules, and visualize the saved files.
    4. Rebuild the dataloaders with ``batch_size = 1`` and sample five
       visualization chains.

    Raises:
        AssertionError: If ``eval_args.model_path`` is ``None``.
    """
    parser = eval_sample_parse_args()
    eval_args, _ = parser.parse_known_args()

    assert eval_args.model_path is not None

    # SECURITY: pickle.load can execute arbitrary code contained in the
    # file. Only use a model_path whose contents you trust.
    with open(join(eval_args.model_path, 'args.pickle'), 'rb') as f:
        args = pickle.load(f)

    # CAREFUL: defaults are filled in when the pickled args lack these
    # attributes (presumably args saved before the options existed).
    if not hasattr(args, 'normalization_factor'):
        args.normalization_factor = 1
    if not hasattr(args, 'aggregation_method'):
        args.aggregation_method = 'sum'

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device
    utils.create_folders(args)
    print(args)

    # Optionally override the training dataset with the one named on the
    # command line. Only a missing attribute is tolerated; anything else
    # should surface instead of being swallowed.
    try:
        target_dataset = eval_args.dataset
    except AttributeError:
        print('Load dataset from ', args.dataset)
    else:
        if target_dataset != 'None':
            print('Load target dataset from ', target_dataset)
            args.is_da_mg = True
            args.dataset = target_dataset

    # Retrieve dataloaders
    args.batch_size = eval_args.batch_size_gen
    dataloaders, _ = qmds.retrieve_dataloaders(args)
    if args.pas_data:
        args.ood_element_size = dataloaders['train'].dataset.num_node_features
        n_nodes_info = None
        dataset_info = None
    else:
        dataset_info = get_dataset_info(args.dataset, args.remove_h)
        args.ood_element_size = len(dataset_info['atom_decoder'])
        n_nodes_info = dataset_info['n_nodes']

    generative_model, nodes_dist, _ = get_da_diffusion(
        args, device, dataloaders, n_nodes_info)

    # Prepare target mol_data
    print(f'use {eval_args.target_domain} data as condition')
    target_data_loader = dataloaders[eval_args.target_domain]
    target_data = next(iter(target_data_loader))

    generative_model.to(device)

    fn = 'generative_model_ema.npy' if args.ema_decay > 0 else 'generative_model.npy'
    da_state_dict = torch.load(join(eval_args.model_path, fn),
                               map_location=device)
    generative_model.load_state_dict(da_state_dict)

    print('Sampling handful of molecules.')
    target_batch_size = target_data['positions'].size(0)
    if target_batch_size < eval_args.batch_size_gen:
        # The target batch is smaller than requested: sample in several
        # rounds of target_batch_size, offsetting the file ids per round.
        for i in range(eval_args.batch_size_gen // target_batch_size):
            sample_different_sizes_and_save(
                args, eval_args, device, generative_model, target_data, nodes_dist,
                dataset_info=dataset_info, n_samples=target_batch_size,
                id_from=i * target_batch_size)
    else:
        sample_different_sizes_and_save(
            args, eval_args, device, generative_model, target_data, nodes_dist,
            dataset_info=dataset_info, n_samples=eval_args.batch_size_gen)

    eval_args.batch_size_gen = min(target_batch_size, eval_args.batch_size_gen)
    print('Sampling stable molecules.')
    # n_samples is a count, so use integer division rather than a float.
    sample_only_stable_different_sizes_and_save(
        args, eval_args, device, generative_model, target_data, nodes_dist,
        dataset_info=dataset_info, n_samples=eval_args.batch_size_gen // 2,
        n_tries=eval_args.batch_size_gen)
    print('Visualizing molecules.')
    vis.visualize(
        join(eval_args.model_path, f'sample_{eval_args.target_domain}/molecules/'), dataset_info,
        max_num=100, spheres_3d=True)

    # Rebuild the loaders with one molecule per batch for chain sampling,
    # using the same loader factory as above.
    args.batch_size = 1
    dataloaders, _ = qmds.retrieve_dataloaders(args)
    # Prepare target mol_data
    print(f'use {eval_args.target_domain} data as condition')
    target_data_loader = dataloaders[eval_args.target_domain]
    print('Sampling visualization chain.')
    save_and_sample_chain(
        args, eval_args, device, generative_model, target_data_loader,
        n_tries=eval_args.n_tries, n_nodes=eval_args.n_nodes,
        dataset_info=dataset_info, num_chains=5)


# Run the sampling pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
